""" Create splits of the data into train and test data used for cross-validation """

import random

from pySPACE.missions.nodes.base_node import BaseNode
from pySPACE.tools.memoize_generator import MemoizeGenerator
import logging

class CrossValidationSplitterNode(BaseNode):
    """ Perform (stratified) cross-validation
    
    During benchmarking, n pairs of training and test data are generated, where
    n is configurable via the parameter splits. The n test datasets are pairwise
    disjoint. Internally, the available data is partitioned into n pairwise 
    disjoint sets s_1, ..., s_n of approximately equal size (the "splits"). The
    i-th pair of training and test data is generated by using s_i as test data
    and the union of the remaining datasets as training data.
    
    The partitioning is stratified by default, i.e. the splits have the same 
    class ratio as the overall dataset. By default, the partitioning is based 
    on shuffling the data randomly. In this case, the partitioning of the data 
    into s_1, ..., s_n is determined solely based on the run number (used as 
    random seed), yielding the same split for the same run_number and different 
    ones for two different run_numbers.
    
    **Parameters**
    
      :splits:
            The number of splits created internally. If n data points exist and
            m splits are created, each of these splits consists of approx. n/m
            data points. 
            
            (*optional, default: 10*)
        
      :stratified:
         If true, the cross-validation is stratified, i.e. the overall 
         class-ratio is retained in each split (as well as possible). 
         
         (*optional, default: True*)
         
      :random:
         If true, the order of the data is randomly shuffled. 
         
         (*optional, default: True*)
         
      :time_dependent:
         If True, splitting is done separately for different (= not 
         overlapping) time windows to ensure that instances corresponding to the
         same marker will be in the same split.
         
         .. note:: Stratification is only allowed here if there is only one 
                   class label for one marker. Otherwise stratification is
                   switched off with a warning.
         
         (*optional, default: False*)

      :stratified_class:
         
         If *time_dependent* is True and *stratified_class* is specified, 
         stratification is only done for the specified class label (String).
         This takes precedence over *stratified*. Instances of the other
         classes fill up the splits while preserving the time order of the 
         data. This also means that *random* has no effect here.

         (*optional, default: None*)

    **Exemplary Call**
    
    .. code-block:: yaml
    
        -
            node : CV_Splitter
            parameters :
                  splits : 10
                  stratified : True
    
    :Author: Jan Hendrik Metzen (jhm@informatik.uni-bremen.de)
    :Created: 2008/12/16
    """
    
    def __init__(self,  splits=10, stratified=True, random=True,
                 time_dependent=False, stratified_class = None,  *args, **kwargs):
        super(CrossValidationSplitterNode, self).__init__(*args, **kwargs)
        
        self.set_permanent_attributes(splits = int(splits), #how many splits
                                      current_split = 0, # current split for testing
                                      split_indices = None,
                                      run_number = -1,
                                      random = random,
                                      stratified = stratified,
                                      stratified_class = stratified_class,
                                      time_dependent = time_dependent)

    def is_split_node(self):
        """ Return whether this is a split node """
        return True

    def use_next_split(self):
        """ Use the next split of the data into training and test data.
        
        Returns True if more splits are available, otherwise False.
        
        This method is useful for benchmarking
        """
        if self.current_split + 1 < self.splits:
            self.current_split = self.current_split + 1
            self._log("Benchmarking with split %s/%s" % (self.current_split + 1,
                                                         self.splits))
            return True
        else:
            return False
    
    def train_sweep(self, use_test_data):
        """ Performs the actual training of the node.
        
        .. note:: Split nodes cannot be trained
        """
        raise Exception("Split nodes cannot be trained")
        
    def request_data_for_training(self, use_test_data):
        """ Returns the data for training of subsequent nodes

        The training data of the current cross-validation run consists of all
        data points that are *not* part of the current test split. The splits
        are created lazily on the first request.

        :param use_test_data: Not used by this node.
        :returns: ``fresh()`` of a MemoizeGenerator over the training data
                  of the current split.
        """
        # Create cv-splits lazily when required
        if self.split_indices is None:
            self._create_splits()

        # All data can be used for training which is not explicitly
        # specified for testing by the current cv-split. A set turns the
        # per-item membership test from a list scan into a hash lookup.
        test_indices = set(self.split_indices[self.current_split])
        self.data_for_training = MemoizeGenerator(
                instance for index, instance in enumerate(self.data)
                if index not in test_indices)
        
        return self.data_for_training.fresh()
    
    def request_data_for_testing(self):
        """ Returns the data for testing of subsequent nodes

        The test data of the current cross-validation run is exactly the
        current split. The splits are created lazily on the first request.

        :returns: ``fresh()`` of a MemoizeGenerator over the test data of
                  the current split.
        """
        # Create cv-splits lazily when required
        if self.split_indices is None:
            self._create_splits()
        
        # Only that data can be used for testing which is explicitly
        # specified for this purpose by the current cv-split
        self.data_for_testing = MemoizeGenerator(
                self.data[i] for i in self.split_indices[self.current_split])
        
        return self.data_for_testing.fresh()

    def _create_splits(self):
        """ Create the split of the data for n-fold cross-validation

        Requests all data from the input node, stores it in ``self.data`` and
        sets ``self.split_indices`` to a list with one entry per split; the
        j-th entry holds the indices (into ``self.data``) used for testing in
        the j-th cross-validation run.

        :raises Exception: if the input node already provides training data,
                           i.e. if the data was already split before.
        """
        self._log("Creating %s splits for cross validation" % self.splits)
                  
        # Get training and test data (with labels)
        train_data = \
          list(self.input_node.request_data_for_training(use_test_data=False))
        test_data = list(self.input_node.request_data_for_testing())
        
        # If there is already a non-empty training set, 
        # it means that we are not the first split node in the node chain
        if len(train_data) > 0:
            raise Exception("No iterated splitting of data sets allowed\n "
                            "(Calling a splitter on a data set that is "
                            "already split)")
        
        # Remember all the data and store it in memory
        # TODO: This might cause problems for large datasets
        self.data = train_data + test_data
        
        if self.time_dependent:
            split_indices = self._time_dependent_splits()
        elif self.stratified:
            split_indices = self._stratified_splits()
        else:
            split_indices = self._plain_splits()

        self.split_indices = split_indices
        
        self._log("Benchmarking with split %s/%s" % (self.current_split + 1,
                                                     self.splits))

    def _split_boundaries(self, data_size):
        """ Return the half-open intervals [start, end) of all splits

        Divides ``data_size`` consecutive positions into ``self.splits``
        chunks of approximately equal size.
        """
        return [(int(round(float(j) * data_size / self.splits)),
                 int(round(float(j + 1) * data_size / self.splits)))
                for j in range(self.splits)]

    def _plain_splits(self):
        """ Non-stratified, not time-dependent splitting of ``self.data`` """
        data_size = len(self.data)
        # We cannot have more splits than data points
        assert data_size >= self.splits, \
            "Cannot create %s splits from %s data points." % (self.splits,
                                                             data_size)

        # Set random seed and randomize the order of the data
        indices = list(range(data_size))
        if self.random:
            random.Random(self.run_number).shuffle(indices)

        return [indices[start:end]
                for start, end in self._split_boundaries(data_size)]

    def _stratified_splits(self):
        """ Stratified, not time-dependent splitting of ``self.data``

        Each class is distributed over the splits separately. If some class
        has fewer instances than requested splits, ``self.splits`` is reduced.
        """
        # divide the data with respect to the class label
        data_labeled = dict()
        for index, (_, label) in enumerate(self.data):
            data_labeled.setdefault(label, []).append(index)

        # we should not have more splits than instances of every class!
        min_nr_per_class = min(len(indices)
                               for indices in data_labeled.values())
        if self.splits > min_nr_per_class:
            self.splits = min_nr_per_class
            self._log("Reducing number of splits to %s since no more "
                      "instances of one of the classes are available." 
                      % self.splits, level=logging.CRITICAL)

        split_indices = [[] for _ in range(self.splits)]
        for label, indices in data_labeled.items():
            data_size = len(indices)
            # Set random seed and randomize the order of the data
            if self.random:
                random.Random(self.run_number).shuffle(indices)
            for j, (start, end) in enumerate(self._split_boundaries(data_size)):
                split_indices[j].extend(indices[start:end])

        # avoid splits in which the instances are grouped by label
        # NOTE(review): this shuffle also happens when random is False
        for split_list in split_indices:
            random.Random(self.run_number).shuffle(split_list)
        return split_indices

    def _group_by_marker(self):
        """ Group the (time-sorted) data into non-overlapping time windows

        Returns a pair ``(data_time, label_marker)``: ``data_time`` maps a
        consecutive marker number to the list of data indices in that group;
        ``label_marker`` maps each class label to the markers whose first
        instance has that label (only filled if stratification is requested).
        If one marker holds several labels, stratification is switched off.
        """
        import warnings

        data_time = dict()
        label_marker = dict()
        last_window_end_time = 0.0
        marker = -1
        for index, (window, label) in enumerate(self.data):
            # The first window always opens a new marker (also if it starts
            # at time <= 0); later ones only if they do not overlap the
            # previous window.
            if marker < 0 or window.start_time > last_window_end_time:
                marker += 1
                data_time[marker] = [index]
                if self.stratified or self.stratified_class:
                    label_marker.setdefault(label, []).append(marker)
            else:
                data_time[marker].append(index)
                # check label consistency for later stratification
                if (self.stratified or self.stratified_class) and \
                        self.data[data_time[marker][0]][1] != label:
                    warnings.warn(
                        "Since there are several class labels"
                        " for one marker stratification is set to False.",
                        UserWarning)
                    self.stratified = False
                    self.stratified_class = None
            last_window_end_time = window.end_time
        return data_time, label_marker

    def _time_dependent_splits(self):
        """ Time-dependent splitting: all instances of a marker share a split """
        # sort the data according to start_time
        self.data.sort(key=lambda instance: instance[0].start_time)
        data_time, label_marker = self._group_by_marker()

        # stratified_class is the more specific option and therefore
        # takes precedence over (the default-on) stratified
        if self.stratified_class:
            return self._class_stratified_marker_splits(data_time,
                                                        label_marker)
        elif self.stratified:
            return self._stratified_marker_splits(data_time, label_marker)
        return self._marker_splits(data_time)

    def _stratified_marker_splits(self, data_time, label_marker):
        """ Distribute the markers of every class evenly over the splits """
        # not more splits than markers of every class!
        assert min(len(markers) for markers in label_marker.values()) \
            >= self.splits, \
            "Not enough markers of every class for %s splits." % self.splits

        split_indices = [[] for _ in range(self.splits)]
        for label, markers in label_marker.items():
            data_size = len(markers)
            # Set random seed and randomize the order of the data
            if self.random:
                random.Random(self.run_number).shuffle(markers)
            for j, (start, end) in enumerate(self._split_boundaries(data_size)):
                for current_marker in markers[start:end]:
                    split_indices[j].extend(data_time[current_marker])
        # avoid sorted labels by sorting time dependent
        return [sorted(split_list) for split_list in split_indices]

    def _class_stratified_marker_splits(self, data_time, label_marker):
        """ Stratify only ``self.stratified_class`` and fill up in time order

        The markers of ``self.stratified_class`` are divided (in time order)
        over the splits. Every instance of another class is then assigned to
        the first split whose largest index lies after it; the remaining
        instances go to the last split.
        """
        class_markers = label_marker.get(self.stratified_class, [])
        data_size = len(class_markers)
        # every split needs at least one marker of the stratified class
        assert data_size >= self.splits, \
            "Cannot create %s splits: only %s markers of class %s." % (
                self.splits, data_size, self.stratified_class)

        split_indices = [[] for _ in range(self.splits)]
        for j, (start, end) in enumerate(self._split_boundaries(data_size)):
            for current_marker in class_markers[start:end]:
                split_indices[j].extend(data_time[current_marker])

        # fill up with other classes
        last_max_index = 0
        for split_list in split_indices:
            max_index = max(split_list)
            split_list.extend(i for i in range(last_max_index, max_index)
                              if self.data[i][1] != self.stratified_class)
            last_max_index = max_index + 1
        split_indices[-1].extend(
            i for i in range(last_max_index, len(self.data))
            if self.data[i][1] != self.stratified_class)
        # avoid sorted labels by sorting time dependent
        return [sorted(split_list) for split_list in split_indices]

    def _marker_splits(self, data_time):
        """ Distribute the markers over the splits without stratification """
        # we should not have more splits than (marker) time points
        data_size = len(data_time)
        assert data_size >= self.splits, \
            "Cannot create %s splits from %s markers." % (self.splits,
                                                         data_size)

        # Set random seed and randomize the order of the data
        markers = list(data_time.keys())
        if self.random:
            random.Random(self.run_number).shuffle(markers)

        split_indices = []
        for start, end in self._split_boundaries(data_size):
            split_list = []
            for current_marker in markers[start:end]:
                split_list.extend(data_time[current_marker])
            # avoid sorted labels by sorting time dependent
            split_indices.append(sorted(split_list))
        return split_indices


# Maps the node name used in YAML specifications to the implementing class
_NODE_MAPPING = dict(CV_Splitter=CrossValidationSplitterNode)
